import time
import numpy as np
from typing import List, Dict, Tuple
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.linear_model import LogisticRegression


# -------------------------------------------------------------------
# Basic IPS estimators
# -------------------------------------------------------------------


def naive_ips(
    pi_e_ij: np.ndarray,
    pi_ij: np.ndarray,
    r_ij: np.ndarray,
) -> float:
    """Naive IPS estimate: the sample mean of ``r * pi_e / pi_b``.

    Importance weights that come out as NaN or +/-inf (for example when a
    logging probability is zero) are replaced by 0 before averaging.
    """
    weights = np.nan_to_num(pi_e_ij / pi_ij, posinf=0.0, neginf=0.0)
    weighted_rewards = r_ij * weights
    return float(np.mean(weighted_rewards))


def balanced_ips(
    pi_e_ij: np.ndarray,
    pi_i_jk: np.ndarray,
    r_ij: np.ndarray,
    n_i: np.ndarray,
) -> float:
    """Balanced IPS (Kallus et al.).

    Every sample j is weighted by ``pi_e_ij[j]`` divided by the mixture
    ``sum_i n_i[i] * pi_i_jk[i, j]``; the weighted rewards are then summed.
    Weights that come out as NaN or +/-inf are replaced by 0.
    """
    # Contract the logger axis: mixture[j] = sum_i n_i[i] * pi_i_jk[i, j]
    mixture = np.tensordot(n_i, pi_i_jk, axes=1)
    weights = np.nan_to_num(pi_e_ij / mixture, posinf=0.0, neginf=0.0)
    return float(np.sum(r_ij * weights))


def weighted_ips(
    pi_e_ij: np.ndarray,
    pi_ij: np.ndarray,
    r_ij: np.ndarray,
    n_i: np.ndarray,
    stratum_idx: List[int],
    n_fold: int,
) -> float:
    """Weighted IPS with cross-fitted variance-based weights.

    For each logger stratum ``i`` (samples ``stratum_idx[i]:stratum_idx[i + 1]``)
    the samples are split into ``n_fold`` shuffled folds.  Within a fold, the
    sample variance (``ddof=1``) of the per-sample IPS rewards on the training
    part is the logger's divergence estimate, and the IPS rewards of the
    held-out part are summed.  Loggers are then combined per fold with weights
    proportional to the inverse divergence, and the fold estimates are
    averaged weighted by their held-out sample counts.

    Parameters
    ----------
    pi_e_ij : evaluation-policy probability of each logged action, one per sample.
    pi_ij : logging-policy probability of each logged action, one per sample.
    r_ij : observed reward of each sample.
    n_i : number of samples in each logger stratum (its length is the number
        of loggers).
    stratum_idx : stratum boundaries; needs ``len(n_i) + 1`` entries.
    n_fold : number of cross-fitting folds passed to ``KFold``.

    Returns
    -------
    float
        The cross-fitted weighted IPS estimate.
    """
    n_loggers = len(n_i)
    n_data = len(pi_e_ij)

    kf = KFold(n_splits=n_fold, shuffle=True, random_state=0)

    # Row = fold, column = logger.
    divs = np.zeros((n_fold, n_loggers))
    vals = np.zeros((n_fold, n_loggers))
    fold_sizes = np.zeros((n_fold, n_loggers))

    # Per logger stratified splitting
    for i in range(n_loggers):
        start, end = stratum_idx[i], stratum_idx[i + 1]

        # The per-sample IPS rewards do not depend on the fold, so compute
        # them once per stratum and index into them below.
        iw = pi_e_ij[start:end] / pi_ij[start:end]
        iw = np.nan_to_num(iw, posinf=0.0, neginf=0.0)
        ips = r_ij[start:end] * iw

        for fold_idx, (train_idx, test_idx) in enumerate(kf.split(np.arange(n_i[i]))):
            # Divergence estimator: sample variance of IPS rewards
            divs[fold_idx, i] = np.var(ips[train_idx], ddof=1)
            # Fold value contribution on test set
            vals[fold_idx, i] = np.sum(ips[test_idx])
            # Fold size
            fold_sizes[fold_idx, i] = len(test_idx)

    # Cross-fitting
    ope_val = 0.0
    for fold_idx in range(n_fold):
        fold_divs = divs[fold_idx]
        # A zero variance would divide by zero, and a one-sample training
        # part makes np.var(ddof=1) return NaN; both fall back to a
        # divergence of 1.0 so one degenerate stratum cannot turn the whole
        # estimate into NaN.
        degenerate = np.isnan(fold_divs) | (fold_divs <= 0.0)
        divs_safe = np.where(degenerate, 1.0, fold_divs)

        lams = (1.0 / divs_safe) / np.sum(fold_sizes[fold_idx] / divs_safe)
        fold_n = np.sum(fold_sizes[fold_idx])
        ope_val += fold_n * np.dot(vals[fold_idx], lams)

    return float(ope_val / n_data)


# -------------------------------------------------------------------
# DR-bIPS (optimality proved in Kallus et al.)
# -------------------------------------------------------------------


def dr_balanced_ips(
    pi_e_ij: np.ndarray,
    pi_i_jk: np.ndarray,
    r_ij: np.ndarray,
    q_ij: np.ndarray,
    q_ij_pi_e: np.ndarray,
    n_i: np.ndarray,
) -> float:
    """
    DR-pi* estimator from Kallus et al.

    Adds the mean of ``q_ij_pi_e`` (direct-method term) to the balanced-IPS
    sum of the residuals ``r_ij - q_ij``.

    pi_i_jk: shape (n_loggers, n_data)
    n_i:     per logger counts (or effective weights), shape (n_loggers,)
    """
    # mixture[j] = sum_i n_i[i] * pi_i_jk[i, j]
    mixture = np.tensordot(n_i, pi_i_jk, axes=1)
    balanced_w = np.nan_to_num(pi_e_ij / mixture, posinf=0.0, neginf=0.0)

    residual = r_ij - q_ij
    correction = np.sum(residual * balanced_w)
    direct = np.mean(q_ij_pi_e)

    return float(correction + direct)


# -------------------------------------------------------------------
# Optimal IPS estimator (Ours)
# -------------------------------------------------------------------


def optimal_ips(
    pi_e_ij: np.ndarray,
    pi_i_jk: np.ndarray,
    r_ij: np.ndarray,
    n_i: np.ndarray,
    stratum_idx: List[int],
    n_fold: int,
) -> float:
    """Optimal IPS (Ours) with cross-fitting.

    For every fold, a matrix ``T`` (loggers x loggers) and a vector ``c`` are
    built from the training part of each logger stratum, the coefficients
    ``alpha = pinv(T) @ c`` are computed, and the estimate is evaluated on the
    held-out samples.  Fold estimates are averaged weighted by their held-out
    sample counts.
    """
    num_loggers = len(n_i)
    total = len(pi_e_ij)

    splitter = KFold(n_splits=n_fold, shuffle=True, random_state=0)

    # One accumulator per fold, kept as three parallel lists.
    T_per_fold = [np.zeros((num_loggers, num_loggers)) for _ in range(n_fold)]
    c_per_fold = [np.zeros(num_loggers) for _ in range(n_fold)]
    held_out_per_fold = [[] for _ in range(n_fold)]

    # Fill row `logger` of T and entry `logger` of c for each fold.
    for logger in range(num_loggers):
        lo, hi = stratum_idx[logger], stratum_idx[logger + 1]
        pi_e_stratum = pi_e_ij[lo:hi]
        pi_all_stratum = pi_i_jk[:, lo:hi]
        r_stratum = r_ij[lo:hi]

        splits = splitter.split(np.arange(n_i[logger]))
        for k, (fit_rows, held_rows) in enumerate(splits):
            held_out_per_fold[k].append(held_rows)

            pi_fit = pi_all_stratum[:, fit_rows]
            # Reciprocal of the n_i-weighted mixture probability per sample.
            inv_mix = 1.0 / np.tensordot(n_i, pi_fit, axes=([0], [0]))

            # Row of T has one entry per logger.
            T_per_fold[k][logger] = (
                n_i[logger]
                * np.tensordot(pi_fit, inv_mix, axes=([1], [0]))
                / pi_fit.shape[1]
            )
            c_per_fold[k][logger] = n_i[logger] * np.mean(
                r_stratum[fit_rows] * pi_e_stratum[fit_rows] * inv_mix
            )

    # Evaluate each fold on its held-out samples.
    estimate = 0.0
    for T, c, held_out in zip(T_per_fold, c_per_fold, held_out_per_fold):
        sizes = np.array([len(rows) for rows in held_out], dtype=float)

        alpha = np.linalg.pinv(T) @ c

        # Shift stratum-local row numbers to positions in the full arrays.
        rows_global = np.concatenate(
            [rows + stratum_idx[s] for s, rows in enumerate(held_out)]
        )
        pi_held = pi_i_jk[:, rows_global]
        inv_mix = 1.0 / np.tensordot(sizes, pi_held, axes=([0], [0]))

        term1 = np.sum(alpha)
        term2 = np.sum(r_ij[rows_global] * pi_e_ij[rows_global] * inv_mix)
        term3 = np.sum(np.tensordot(alpha, pi_held, axes=([0], [0])) * inv_mix)

        estimate += np.sum(sizes) * (term1 + term2 - term3)

    return float(estimate / total)


# -------------------------------------------------------------------
# Doubly-robust extension of Optimal IPS estimator (Ours)
# -------------------------------------------------------------------


def dr_optimal_ips(
    pi_e_ij: np.ndarray,
    pi_i_jk: np.ndarray,
    r_ij: np.ndarray,
    q_ij: np.ndarray,
    q_ij_pi_e: np.ndarray,
    n_i: np.ndarray,
    stratum_idx: List[int],
    n_fold: int,
) -> float:
    """Doubly-robust Optimal IPS (Ours) with cross-fitting.

    Runs the cross-fitted optimal IPS estimator on the residual rewards
    ``r_ij - q_ij`` and adds back the mean of ``q_ij_pi_e`` (the direct-method
    value J_DM).

    Parameters
    ----------
    pi_e_ij, pi_i_jk, n_i, stratum_idx, n_fold :
        Forwarded unchanged to ``optimal_ips``.
    r_ij : observed reward of each sample.
    q_ij : reward-model estimate for each logged (state, action) pair.
    q_ij_pi_e : reward-model estimate for each state under the evaluation
        policy's action.

    Returns
    -------
    float
        The doubly-robust estimate as a plain Python float.
    """
    # The cross-fitting procedure is the same as optimal_ips; only the reward
    # signal differs, so delegate instead of duplicating the implementation.
    residual_value = optimal_ips(
        pi_e_ij=pi_e_ij,
        pi_i_jk=pi_i_jk,
        r_ij=r_ij - q_ij,
        n_i=n_i,
        stratum_idx=stratum_idx,
        n_fold=n_fold,
    )

    # Add back the mean of q_ij_pi_e (which is J_DM); float() keeps the
    # declared return type instead of leaking a NumPy scalar.
    return float(residual_value + np.mean(q_ij_pi_e))


# -------------------------------------------------------------------
# Reward model (q) estimation utilities
# -------------------------------------------------------------------


def _build_action_features(
    X: np.ndarray,
    actions: np.ndarray,
    num_actions: int,
) -> np.ndarray:
    """
    Concatenate state features with one hot action encoding.

    X:       (n, d)
    actions: (n,) int in [0, num_actions)
    -> (n, d + num_actions)
    """
    n = X.shape[0]
    one_hot = np.zeros((n, num_actions), dtype=float)
    one_hot[np.arange(n), actions] = 1.0
    return np.hstack([X, one_hot])


def estimate_q_twofold(
    X: np.ndarray,
    actions: np.ndarray,
    rewards: np.ndarray,
    num_actions: int,
    y_eval_pol: np.ndarray,
    random_state: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Estimate the reward model q with 2-fold cross-fitting.

    The data are split into two folds stratified by action.  For each fold a
    logistic regression is fit on the other fold, using the state features
    concatenated with a one hot action encoding, and is then used to predict
    on the held-out fold.  The predicted value is column 1 of
    ``predict_proba`` (assumes binary rewards in {0, 1} -- TODO confirm).

    If the training fold contains a single reward value, no classifier can be
    fit; the mean training reward is used as a constant prediction instead.

    Returns:
        q_ij: out-of-fold estimates q(s_j, a_j)
        q_ij_pi_e: estimates q(s_j, pi_e(s_j))
    """
    num_data = len(actions)
    q_ij = np.zeros(num_data, dtype=float)
    q_ij_pi_e = np.zeros(num_data, dtype=float)

    skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=random_state)

    for train_idx, ev_idx in skf.split(X, actions):
        r_train = rewards[train_idx]

        # LogisticRegression.fit raises on a single-class target, so fall
        # back to the constant reward seen in training.
        if np.unique(r_train).size < 2:
            constant_q = float(np.mean(r_train))
            q_ij[ev_idx] = constant_q
            q_ij_pi_e[ev_idx] = constant_q
            continue

        # Train a classifier on (s,a) -> r
        X_tr = _build_action_features(X[train_idx], actions[train_idx], num_actions)
        clf = LogisticRegression(random_state=random_state, solver="lbfgs")
        clf.fit(X_tr, r_train)

        # Predict q(s,a_logged) on held-out fold
        X_ev = _build_action_features(X[ev_idx], actions[ev_idx], num_actions)
        q_ij[ev_idx] = clf.predict_proba(X_ev)[:, 1]

        # Predict q(s, a_eval)
        X_eval = _build_action_features(
            X[ev_idx], y_eval_pol[ev_idx].astype(int), num_actions
        )
        q_ij_pi_e[ev_idx] = clf.predict_proba(X_eval)[:, 1]

    return q_ij, q_ij_pi_e


# -------------------------------------------------------------------
# Logging policy estimation (pi_b)
# -------------------------------------------------------------------


def estimate_logging_policies(
    X: np.ndarray,
    actions: np.ndarray,
    pi_e_ij: np.ndarray,
    num_actions: int,
    num_loggers: int,
    stratum_idx: List[int],
    exploration_probs: np.ndarray,
    estimate_pi_b: bool,
    random_state: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Estimate logging policy probabilities.

    With ``estimate_pi_b=False`` the logging policies are treated as known:
    logger ``i`` assigns probability
    ``(1 - p_i) * pi_e_ij + p_i / num_actions`` to a logged action, where
    ``p_i = exploration_probs[i]``.

    With ``estimate_pi_b=True`` a multinomial logistic regression is fit per
    logger stratum.  ``pi_ij`` uses 2-fold out-of-fold predictions and
    ``pi_i_jk`` uses a model fit on the whole stratum and applied to all
    samples.  Predictions are mixed with the uniform distribution (weight
    0.01) and finally clipped to ``[1e-3, 1]``.  Whenever a model cannot be
    fit (fewer than two action classes in the training data, or too few
    samples) the uniform probability ``1 / num_actions`` is used.

    Parameters
    ----------
    X : state features, one row per sample.
    actions : logged integer actions, one per sample.
    pi_e_ij : evaluation-policy probability of each logged action (only used
        when ``estimate_pi_b`` is False).
    num_actions : number of actions.
    num_loggers : number of logging policies / strata.
    stratum_idx : stratum boundaries; stratum ``i`` is
        ``stratum_idx[i]:stratum_idx[i + 1]``.
    exploration_probs : per-logger exploration probability (only used when
        ``estimate_pi_b`` is False).
    estimate_pi_b : whether to estimate the logging policies from data.
    random_state : seed for the fold split and the classifiers.

    Returns
    -------
    pi_ij : shape (num_data,)
        Probability of the logged action under its own logger policy.
    pi_i_jk : shape (num_loggers, num_data)
        For each logger i and each sample j, probability that logger i
        would have taken the logged action a_j in state s_j.
    """
    num_data = actions.shape[0]
    pi_ij = np.zeros(num_data, dtype=float)
    pi_i_jk = np.zeros((num_loggers, num_data), dtype=float)

    if not estimate_pi_b:
        # Known logging policies: epsilon mixture around evaluation policy
        for i in range(num_loggers):
            p = exploration_probs[i]
            probs_i = (1.0 - p) * pi_e_ij + (p / num_actions)

            # Probability of actually logged actions in stratum i
            own = slice(stratum_idx[i], stratum_idx[i + 1])
            pi_ij[own] = probs_i[own]

            # Probability under logger i for every sample in any stratum.
            # Strata are consecutive, so one slice covers all of them.
            covered = slice(stratum_idx[0], stratum_idx[num_loggers])
            pi_i_jk[i, covered] = probs_i[covered]

        return pi_ij, pi_i_jk

    # Unknown logging policies: per logger 2-fold estimation with smoothing
    mix_eps = 0.01
    eps_floor = 1e-3
    uniform_prob = 1.0 / num_actions

    # NOTE(review): `multi_class` appears to be deprecated in recent
    # scikit-learn releases -- confirm against the pinned version before
    # upgrading.  It is kept here so fitted probabilities do not change.
    for i in range(num_loggers):
        start, end = stratum_idx[i], stratum_idx[i + 1]

        X_i = X[start:end]
        a_i = actions[start:end]

        has_two_classes = np.unique(a_i).size >= 2
        enough_samples = X_i.shape[0] >= 4

        # 1) pi_ij: out of fold probability of logged action under logger i.
        # Start from the uniform fallback and overwrite every held-out fold
        # for which a classifier can actually be fit.
        pi_ij[start:end] = uniform_prob
        if has_two_classes and enough_samples:
            skf_local = StratifiedKFold(
                n_splits=2,
                shuffle=True,
                random_state=random_state,
            )
            for tr_loc, ev_loc in skf_local.split(X_i, a_i):
                # A class with very few members can be missing from one
                # training half even though the stratum has two classes;
                # fitting would raise, so keep the uniform fallback.
                if np.unique(a_i[tr_loc]).size < 2:
                    continue

                clf = LogisticRegression(
                    random_state=random_state, solver="lbfgs", multi_class="multinomial"
                )
                clf.fit(X_i[tr_loc], a_i[tr_loc])

                # Actions unseen in training keep probability 0 before smoothing.
                full_ev = np.zeros((ev_loc.shape[0], num_actions), dtype=float)
                full_ev[:, clf.classes_.astype(int)] = clf.predict_proba(X_i[ev_loc])

                full_ev = (1.0 - mix_eps) * full_ev + (mix_eps / num_actions)
                pi_ij[start + ev_loc] = full_ev[
                    np.arange(ev_loc.shape[0]), a_i[ev_loc]
                ]

        # 2) pi_i_jk: probability of logged action for all samples under logger i
        if has_two_classes and X_i.shape[0] >= 2:
            clf_full = LogisticRegression(
                random_state=random_state, solver="lbfgs", multi_class="multinomial"
            )
            clf_full.fit(X_i, a_i)

            full_all = np.zeros((num_data, num_actions), dtype=float)
            full_all[:, clf_full.classes_.astype(int)] = clf_full.predict_proba(X)
        else:
            # fallback: uniform
            full_all = np.full((num_data, num_actions), uniform_prob, dtype=float)

        full_all = (1.0 - mix_eps) * full_all + (mix_eps / num_actions)
        pi_i_jk[i, :] = full_all[np.arange(num_data), actions]

    # Floor the estimates so importance weights stay bounded.
    pi_ij = np.clip(pi_ij, eps_floor, 1.0)
    pi_i_jk = np.clip(pi_i_jk, eps_floor, 1.0)

    return pi_ij, pi_i_jk


# -------------------------------------------------------------------
# Main OPE runner
# -------------------------------------------------------------------


def run_ope(ope_args: Dict) -> Dict[str, float]:
    """Simulate multi-logger bandit feedback and run every OPE estimator.

    Reads these keys from ``ope_args``: ``states``, ``y_gt``, ``y_eval_pol``,
    ``num_actions``, ``num_loggers``, ``exploration_probs``,
    ``stratum_ratio``, ``n_fold``, ``estimate_pi_b`` and ``random_state``.

    The samples are cut into ``num_loggers`` consecutive strata sized by
    ``stratum_ratio`` (the last stratum takes the remainder).  Within stratum
    ``i`` the logged action starts as the evaluation policy's action and is
    replaced by a uniformly drawn action with probability
    ``exploration_probs[i]``.  The reward is 1 when the logged action equals
    ``y_gt`` and 0 otherwise.

    Returns
    -------
    Dict[str, float]
        For every estimator ``<name>`` two entries: ``J_<name>`` (the
        estimate) and ``time_J_<name>`` (elapsed seconds for that call).
    """
    X = ope_args["states"]
    y_gt = ope_args["y_gt"]
    y_eval_pol = ope_args["y_eval_pol"]
    num_actions = ope_args["num_actions"]
    num_loggers = ope_args["num_loggers"]
    exploration_probs = np.array(ope_args["exploration_probs"], dtype=float)
    stratum_ratio = ope_args["stratum_ratio"]
    n_fold = ope_args["n_fold"]
    estimate_pi_b = ope_args["estimate_pi_b"]
    random_state = ope_args["random_state"]

    num_data = len(y_gt)

    rng = np.random.RandomState(random_state)

    # ------------------------------------------------------------
    # Sample logging actions per stratum
    # ------------------------------------------------------------
    stratum_sizes = np.array(num_data * np.array(stratum_ratio), dtype=np.int32)
    stratum_idx: List[int] = [0]
    for i in range(num_loggers - 1):
        stratum_idx.append(int(stratum_sizes[: i + 1].sum()))
    stratum_idx.append(num_data)

    y_logging_pol = y_eval_pol.copy()
    for i in range(num_loggers):
        start, end = stratum_idx[i], stratum_idx[i + 1]
        # Assumes y_eval_pol is a NumPy array, so this slice is a view and
        # the masked assignment below edits y_logging_pol in place.
        y_logging_i = y_logging_pol[start:end]
        p = exploration_probs[i]
        resample_mask = rng.uniform(size=len(y_logging_i)) <= p
        y_logging_i[resample_mask] = rng.choice(
            num_actions,
            size=resample_mask.sum(),
        )
        # Record the real stratum length (the last stratum absorbs the
        # remainder left by the integer truncation above).
        stratum_sizes[i] = len(y_logging_i)

    # ------------------------------------------------------------
    # pi_e_ij, logging policies, and rewards
    # ------------------------------------------------------------
    logged_actions = y_logging_pol.astype(int)

    # Indicator of "logged action equals the evaluation policy's action".
    pi_e_ij = (y_logging_pol == y_eval_pol).astype(float)

    pi_ij, pi_i_jk = estimate_logging_policies(
        X=X,
        actions=logged_actions,
        pi_e_ij=pi_e_ij,
        num_actions=num_actions,
        num_loggers=num_loggers,
        stratum_idx=stratum_idx,
        exploration_probs=exploration_probs,
        estimate_pi_b=estimate_pi_b,
        random_state=random_state,
    )

    r_ij = (y_gt == y_logging_pol).astype(float)

    # ------------------------------------------------------------
    # Reward estimates for doubly robust methods
    # ------------------------------------------------------------
    q_ij, q_ij_pi_e = estimate_q_twofold(
        X=X,
        actions=logged_actions,
        rewards=r_ij,
        y_eval_pol=y_eval_pol,
        num_actions=num_actions,
        random_state=random_state,
    )

    # ------------------------------------------------------------
    # Collect estimators
    # ------------------------------------------------------------
    result: Dict[str, float] = {}

    def _record(name, estimator, **kwargs):
        """Run one estimator and store its value and elapsed time."""
        # perf_counter is monotonic and high resolution, unlike time.time().
        started = time.perf_counter()
        result[f"J_{name}"] = estimator(**kwargs)
        result[f"time_J_{name}"] = time.perf_counter() - started

    _record(
        "naive_ips",
        naive_ips,
        pi_e_ij=pi_e_ij,
        pi_ij=pi_ij,
        r_ij=r_ij,
    )
    _record(
        "balanced_ips",
        balanced_ips,
        pi_e_ij=pi_e_ij,
        pi_i_jk=pi_i_jk,
        r_ij=r_ij,
        n_i=stratum_sizes,
    )
    _record(
        "weighted_ips",
        weighted_ips,
        pi_e_ij=pi_e_ij,
        pi_ij=pi_ij,
        r_ij=r_ij,
        n_i=stratum_sizes,
        stratum_idx=stratum_idx,
        n_fold=n_fold,
    )
    _record(
        "dr_balanced_ips",
        dr_balanced_ips,
        pi_e_ij=pi_e_ij,
        pi_i_jk=pi_i_jk,
        r_ij=r_ij,
        q_ij=q_ij,
        q_ij_pi_e=q_ij_pi_e,
        n_i=stratum_sizes,
    )
    _record(
        "optimal_ips",
        optimal_ips,
        pi_e_ij=pi_e_ij,
        pi_i_jk=pi_i_jk,
        r_ij=r_ij,
        n_i=stratum_sizes,
        stratum_idx=stratum_idx,
        n_fold=n_fold,
    )
    _record(
        "dr_optimal_ips",
        dr_optimal_ips,
        pi_e_ij=pi_e_ij,
        pi_i_jk=pi_i_jk,
        r_ij=r_ij,
        q_ij=q_ij,
        q_ij_pi_e=q_ij_pi_e,
        n_i=stratum_sizes,
        stratum_idx=stratum_idx,
        n_fold=n_fold,
    )

    return result
